#!/usr/bin/python3
""" Quantify related steps.
1. Quantify umodel
2. Convert dataset
"""
from __future__ import print_function
import argparse
import os
import shutil
import subprocess

def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When None (the default) argparse
        reads them from ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Parsed options ``mode``, ``model_path``, ``weight_path``,
        ``iters``, ``src_path``, ``list_path`` and ``lmdb_path``.
        Every option is a string; an option that was not given is None.
    """
    parser = argparse.ArgumentParser(
            description="quantify tools")
    parser.add_argument('--mode', type=str,
                        choices=['quantify', 'convert'],
                        help="mode: quantify umodel | convert dataset")
    # All remaining options are plain string options that differ only in
    # their name and help text.
    string_options = (
        ('--model_path', "fp32 model prototxt"),
        ('--weight_path', "fp32umodel "),
        ('--iters', "calibration iterations"),
        ('--src_path', "input source dataset path"),
        ('--list_path', "dataset file list path"),
        ('--lmdb_path', "lmdb dataset path"),
    )
    for flag, help_text in string_options:
        parser.add_argument(flag, type=str, help=help_text)
    return parser.parse_args(argv)

def tolmdb(dataset, filelist, lmdbpath):
    """ Convert image data to lmdb format

    Removes any existing directory at ``lmdbpath`` and then runs
    ``convert_imageset --shuffle dataset filelist lmdbpath``.

    Parameters
    ----------
    dataset : str
        input dataset path
    filelist : str
        dataset file path
    lmdbpath : str
        output lmdb path

    Returns
    -------
    int
        exit status of ``convert_imageset``, or 127 when the program
        could not be started.

    Raises
    ------
    ValueError
        if any of the three paths is None.
    """
    # Validate before touching the filesystem so that a missing argument
    # never causes the existing lmdb directory to be deleted for nothing.
    for name, value in (("dataset", dataset),
                        ("filelist", filelist),
                        ("lmdbpath", lmdbpath)):
        if value is None:
            raise ValueError("tolmdb: '{}' is required".format(name))

    if os.path.isdir(lmdbpath):
        shutil.rmtree(lmdbpath)

    # An argument list (no shell) keeps paths containing spaces intact and
    # prevents shell metacharacters in a path from being executed.
    cmd = ["convert_imageset", "--shuffle", dataset, filelist, lmdbpath]
    try:
        return subprocess.call(cmd)
    except OSError as err:
        print("failed to run {}: {}".format(cmd[0], err))
        return 127

def int8quantify(model_path, weight_path, lmdb_path, iters):
    """ quantify fp32umodel with lmdb data

    Runs ``calibration_use_pb release`` with ``--model``, ``--weights``,
    ``--iterations`` and ``--bitwidth=TO_INT8``.

    Parameters
    ----------
    model_path : str
        fp32 model prototxt
    weight_path : str
        fp32umodel path
    lmdb_path : str
        input calibration dataset. NOTE: this value is not passed to the
        calibration command by this function; presumably the dataset
        location is taken from the model prototxt -- TODO confirm.
    iters : str or int
        calibration iterations

    Returns
    -------
    int
        exit status of ``calibration_use_pb``, or 127 when the program
        could not be started.

    Raises
    ------
    ValueError
        if ``model_path``, ``weight_path`` or ``iters`` is None.
    """
    for name, value in (("model_path", model_path),
                        ("weight_path", weight_path),
                        ("iters", iters)):
        if value is None:
            raise ValueError("int8quantify: '{}' is required".format(name))

    # An argument list (no shell) keeps paths containing spaces intact and
    # prevents shell metacharacters in a path from being executed.
    cmd = [
        "calibration_use_pb", "release",
        "--model=" + model_path,
        "--weights=" + weight_path,
        "--iterations=" + str(iters),
        "--bitwidth=TO_INT8",
    ]
    try:
        return subprocess.call(cmd)
    except OSError as err:
        print("failed to run {}: {}".format(cmd[0], err))
        return 127

if __name__ == '__main__':
    # Script entry point: dispatch on --mode, or print usage when it is
    # missing.
    args = parse_args()
    if args.mode == "quantify":
        # Forward whatever the helper returns as the process exit status;
        # a None return exits with status 0.
        raise SystemExit(int8quantify(args.model_path, args.weight_path,
                                      args.lmdb_path, args.iters))
    elif args.mode == "convert":
        raise SystemExit(tolmdb(args.src_path, args.list_path,
                                args.lmdb_path))
    else:
        # argparse rejects values outside `choices`, so this branch is only
        # reached when --mode was not given at all: show the usage summary.
        pad = " " * 6
        print("mode -> quantify fp32model")
        print(pad + "--mode quantify\n" +
              pad + "--model_path pbtxt path\n" +
              pad + "--weight_path fp32umodel path\n" +
              pad + "--lmdb_path calibration dataset path\n" +
              pad + "--iters calibration iteration")
        print("mode -> convert image dataset")
        print(pad + "--mode convert\n" +
              pad + "--src_path dataset_path\n" +
              pad + "--list_path dataset file list path\n" +
              pad + "--lmdb_path output lmdb dataset path\n")

